#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import os

# Kernel name stem for each accepted --type value; main() uses it for the
# generated file name, the included .cuinl name and the do_<stem><suffix>
# function being instantiated.
PREFIXES = dict(
    dp4a="conv_bias_int8_implicit_gemm_cdiv4hwn4",
    imma="conv_bias_int8_implicit_gemm",
)

# Activations keyed 1..3.  Each value is (name substituted after
# "NonlineMode::" in the instantiation, tag appended to the file name).
ACTIVATIONS = dict(
    enumerate(
        [
            ("IDENTITY", "_id"),
            ("RELU", "_relu"),
            ("H_SWISH", "_hswish"),
        ],
        start=1,
    )
)

# Bias visitors keyed 1..2.  Each value is (type name substituted into the
# instantiation, tag appended to the file name).
BIASES = dict(
    enumerate(
        [
            ("PerElementBiasVisitor", "_per_elem"),
            ("PerChannelBiasVisitor", "_per_chan"),
        ],
        start=1,
    )
)

# Building blocks of the imma suffix list: every shape is listed once per
# variant, variants varying slowest.
_IMMA_SHAPES = ("16x16x16", "8x32x16", "32x8x16")
_IMMA_VARIANTS = ("", "_reorder_filter", "_unroll_width")

# Kernel-variant suffixes for each --type value; one group of files is
# generated per suffix.
SUFFIXES = dict(
    dp4a=["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"],
    imma=[
        "_imma{}_cdiv4hwn4{}".format(shape, variant)
        for variant in _IMMA_VARIANTS
        for shape in _IMMA_SHAPES
    ],
)


def main():
    """Generate one ``.cu`` explicit-instantiation file per kernel variant.

    Command line: ``[--type {dp4a,imma}] output``.

    For every suffix listed in ``SUFFIXES`` for the chosen type, combined
    with every entry of ``ACTIVATIONS``, a file named
    ``<prefix><suffix><bias tag><activation tag>.cu`` is written into the
    ``output`` directory.  Each file contains a "generated by" header
    comment, an include of ``../<prefix><suffix>.cuinl`` and an explicit
    template instantiation of ``do_<prefix><suffix>`` for that bias visitor
    and activation.  Only the bias entry ``BIASES[2]`` is used.

    The output directory is created when missing, and its timestamps are
    refreshed after all files have been written.  The path of every
    generated file is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description="generate cuda conv bias (dp4a/imma) kern impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["dp4a", "imma"],
        default="dp4a",
        help="generate cuda conv bias kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race when several generator
    # invocations target the same directory concurrently.
    os.makedirs(args.output, exist_ok=True)

    # Instantiation template.  It is assembled from explicit literals so the
    # trailing spaces (kept to leave the generated files byte-identical to
    # earlier output of this script) cannot be lost to whitespace stripping.
    # The leading empty entry yields the blank line after the #include.
    inst = "\n".join(
        [
            "",
            "template void megdnn::cuda::conv_bias_int8::do_{prefix}{suffix}<{bias}, ",
            "        IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::{activation}>>>(",
            "        const int8_t* d_src, ",
            "        const int8_t* d_filter, ",
            "        {bias} bias, ",
            "        IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::{activation}>> epilogue, ",
            "        const ConvParam& param, ",
            "        float alpha, ",
            "        float beta, ",
            "        cudaStream_t stream);",
        ]
    )

    # Loop invariants: the name stem and the single bias visitor in use.
    prefix = PREFIXES[args.type]
    bias_type, bias_tag = BIASES[2]

    # product() iterates suffixes in the outer position and activations in
    # the inner one, i.e. the same order as two nested loops.
    variants = itertools.product(SUFFIXES[args.type], ACTIVATIONS.values())
    for suffix, (act_name, act_tag) in variants:
        fname = os.path.join(
            args.output,
            "{}{}{}{}.cu".format(prefix, suffix, bias_tag, act_tag),
        )
        # Named fields are substituted in a single pass, so a substituted
        # value can never be mistaken for another placeholder.
        cur_inst = inst.format(
            prefix=prefix,
            suffix=suffix,
            bias=bias_type,
            activation=act_name,
        )
        lines = [
            "// generated by gen_cuda_conv_bias_kern_impls.py",
            '#include "../{}{}.cuinl"'.format(prefix, suffix),
            cur_inst,
        ]
        with open(fname, "w") as fout:
            fout.write("\n".join(lines) + "\n")

        print("generated {}".format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably lets a build system use the directory as a
    # freshness stamp - confirm against the caller.
    os.utime(args.output)


# Run the generator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
